import math
import torch
import numpy as np
import nltk
from nltk.translate.meteor_score import single_meteor_score


class GPT2LM:
    """Callable GPT-2 wrapper that scores a sentence by its perplexity."""

    def __init__(self, use_tf=False, device=None, little=False):
        """
        Load the GPT-2 tokenizer and language model.

        :param bool use_tf: If true, uses tensorflow GPT-2 model.
        :param device: Torch device (name or ``torch.device``) that holds the
            model and its inputs when ``use_tf`` is false. ``None`` selects
            CUDA when it is available and the CPU otherwise.
        :param bool little: Unused; kept so existing callers keep working.
        :Package Requirements:
            * **torch** (if use_tf = False)
            * **tensorflow** >= 2.0.0 (if use_tf = True)
            * **transformers**

        Language Models are Unsupervised Multitask Learners.
        `[pdf] <https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf>`__
        `[code] <https://github.com/openai/gpt-2>`__
        """
        import logging
        logging.getLogger("transformers").setLevel(logging.ERROR)
        import os
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        import transformers
        self.use_tf = use_tf
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-large")

        # Remember the device so __call__ can put the inputs where the model
        # lives instead of assuming CUDA is present.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        if use_tf:
            self.lm = transformers.TFGPT2LMHeadModel.from_pretrained("gpt2")
        else:
            self.lm = transformers.GPT2LMHeadModel.from_pretrained("gpt2-large", from_tf=False)
            self.lm.to(self.device)

    def __call__(self, sent):
        """
        :param str sent: A sentence.
        :return: Fluency (ppl). With the torch backend, ``nan`` is returned
            when the forward pass raises ``RuntimeError``.
        :rtype: float
        """
        if self.use_tf:
            import tensorflow as tf
            ipt = self.tokenizer(sent, return_tensors="tf", verbose=False)
            ret = self.lm(ipt)[0]
            loss = 0
            for i in range(ret.shape[0]):
                it = ret[i]
                # Log-softmax along axis 1, shifted by the row maximum so the
                # exponentials cannot overflow.
                it = it - tf.reduce_max(it, axis=1)[:, tf.newaxis]
                it = it - tf.math.log(tf.reduce_sum(tf.exp(it), axis=1))[:, tf.newaxis]
                # Pick, at every position but the last, the score of the token
                # that actually follows it.
                it = tf.gather_nd(it, list(zip(range(it.shape[0] - 1), ipt.input_ids[i].numpy().tolist()[1:])))
                loss += tf.reduce_mean(it)
                # Only the first sequence of the batch is scored.
                break
            return math.exp(-loss)

        ipt = self.tokenizer(sent, return_tensors="pt", verbose=False)
        try:
            input_ids = ipt['input_ids'].to(self.device)
            attention_mask = ipt['attention_mask'].to(self.device)
            # Inference only: no autograd graph is needed for a perplexity.
            with torch.no_grad():
                loss = self.lm(input_ids=input_ids,
                               attention_mask=attention_mask,
                               labels=input_ids)[0]
            ppl = math.exp(loss.item())
        except RuntimeError:
            # Best effort: a failed forward pass yields an undefined score
            # rather than aborting the caller.
            ppl = np.nan
        return ppl


def filter_sent(split_sent, pos):
    """Join the tokens of a sentence with the token at ``pos`` left out.

    Args:
        split_sent (list): Tokens of the sentence.
        pos (int): Index of the token to drop.

    Returns:
        str: The remaining tokens joined by single spaces.
    """
    before = split_sent[:pos]
    after = split_sent[pos + 1:]
    return ' '.join(before + after)

# Module-level scorer used by onion_format_example. It is built once at import
# time, which loads the pretrained tokenizer and model in GPT2LM.__init__; the
# torch model is moved to CUDA when available and to the CPU otherwise.
LM = GPT2LM(use_tf=False, device='cuda' if torch.cuda.is_available() else 'cpu')


def get_processed_sent(flag_li, orig_sent):
    """Rebuild a sentence from the tokens whose flag equals 1.

    Args:
        flag_li (list): One flag per token; a token is kept when its flag
            compares equal to 1.
        orig_sent (list): Tokens of the original sentence.

    Returns:
        str: The kept tokens joined by single spaces.
    """
    kept = [token for idx, token in enumerate(orig_sent) if flag_li[idx] == 1]
    return ' '.join(kept)


def get_processed_poison_data(PPL_li, orig_split_sent, bar):
    """Drop the tokens whose removal lowers perplexity by at least ``-bar``.

    Args:
        PPL_li (list): Perplexities; the last entry is used as the baseline
            and each earlier entry is compared against it.
        orig_split_sent (str): Sentence text. It is split on single spaces
            and the final piece is discarded before filtering.
        bar (float): Threshold on ``ppl - baseline``. A token is removed when
            the difference is less than or equal to ``bar``.

    Returns:
        str: The surviving tokens joined by single spaces.
    """
    tokens = orig_split_sent.split(' ')[:-1]
    baseline = PPL_li[-1]
    # 0 marks a token to remove, 1 a token to keep.
    keep_flags = [0 if (ppl - baseline) <= bar else 1 for ppl in PPL_li[:-1]]
    return get_processed_sent(keep_flags, tokens)


def onion_format_example(sent, bar=-10):
    """Filter a sentence by removing tokens that hurt its perplexity.

    `` . `` is appended to ``sent`` and the result is split on single spaces.
    The module-level ``LM`` then scores the text once per token, each time
    with that token left out. The final token is the empty string produced by
    the trailing space, so the final score serves as the baseline that
    ``get_processed_poison_data`` compares the others against.

    Args:
        sent (str): Sentence to filter.
        bar (float): Threshold passed to ``get_processed_poison_data``; a
            token is removed when leaving it out changes the perplexity by
            ``bar`` or less. Defaults to -10.

    Returns:
        str: The tokens that were kept, joined by single spaces. The appended
        ``.`` is part of the result unless it was filtered out itself.
    """
    sent_pre = sent + " " + "." + " "
    split_sent = sent_pre.split(' ')
    single_sent_PPL = [
        LM(filter_sent(split_sent, j)) for j in range(len(split_sent))
    ]
    return get_processed_poison_data(single_sent_PPL, sent_pre, bar=bar)



def order_similarity(chat_gene, example_output):
    """
    Rank chat_gene by METEOR similarity between each 'output_gene' and example_output.

    Every item receives a 'similarity' key and the list is sorted in place;
    the same list object is returned.

    Args:
        chat_gene (list): List of dictionaries containing 'instruction_gene' and 'output_gene'.
        example_output (str): The reference sentence to compare against.

    Returns:
        list: chat_gene, sorted by METEOR similarity in descending order.

    Raises:
        KeyError: If an item has no 'output_gene' key.
    """
    # Only invoke the downloader when WordNet is missing, so repeated calls
    # do not contact the NLTK server every time.
    try:
        nltk.data.find('corpora/wordnet')
    except LookupError:
        nltk.download('wordnet', quiet=True)

    # The reference is the same for every item; tokenize it once.
    reference_tokens = example_output.split()

    for item in chat_gene:
        hypothesis_tokens = item['output_gene'].split()
        item['similarity'] = single_meteor_score(reference_tokens, hypothesis_tokens)

    # Sort chat_gene by similarity in descending order
    chat_gene.sort(key=lambda x: x['similarity'], reverse=True)

    return chat_gene

